[Fmha] Sparse MLA decode kernel selection heuristics#2836
bkryu merged 4 commits into flashinfer-ai:main
Conversation
Port heuristic improvements from trtllm-gen MR flashinfer-ai#885 to FlashInfer's trtllm-gen FMHA kernel selection (SM100/SM103 sparse MLA decode):

1. New `selectSparseMlaGenerationKernel()` separates sparse MLA selection from the non-sparse path with tuned heuristics:
   - numHeadsQ <= 32 (SwapsMmaAb): batch-aware tileSizeQ halving at batch=1 (2x more head-splitting CTAs when the GPU is under-utilized); threshold `batchSize * maxNumCtasPerSeqKv <= MP/8`.
   - numHeadsQ >= 64 (KeepsMmaAb): batch-aware 1CTA/2CTA selection for numHeadsQ=128; threshold `batchSize * numCtasPerToken * 8 > MP`.
   - numHeadsQ=64 now uses KeepsMmaAb (tileSizeQ=64) instead of SwapsMmaAb (tileSizeQ=16), yielding a 1.35-2.74x speedup across all batch sizes.
2. CgaSmemReduction guard scoped to MLA kernels: only suppress it for tileSizeQ >= 32 (headDimQk=576 exceeds the smem limit); non-MLA kernels are unaffected.
3. Fix kernel re-selection deadlocks by guarding mMultiCtasKvMode assignments with `!mSelectNewKernel`:
   - SwapsMmaAb path: preserve the CgaSmemReduction upgrade on re-selection.
   - KeepsMmaAb path: preserve Disabled mode when numCtasPerSeqKv==1 (small topK), avoiding an infinite loop.

Also add a benchmark script, benchmarks/bench_sparse_mla.py, sweeping batch=[1,32,128,512], seqlenKv=[1k-32k], numHeadsQ=[16,32,64,128], dtype=[bf16,e4m3] for the DeepSeek-V3 sparse MLA config.

Update the TRTLLM_GEN_FMHA cubin artifact hash to e7afc4134b (new kernels required for the numHeadsQ=64 1CTA KeepsMmaAb sparse variants).

Validated: 4512/4512 test_trtllm_gen_mla.py tests pass (3168 skipped).
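The two thresholds above can be illustrated with a minimal Python sketch (names and return values are illustrative stand-ins for the C++ implementation, not the actual API; `mp` is the SM count):

```python
def select_sparse_mla_kernel(num_heads_q, batch_size, max_num_ctas_per_seq_kv,
                             num_ctas_per_token, mp):
    """Sketch of the tuned sparse-MLA selection heuristics described above."""
    if num_heads_q <= 32:
        # SwapsMmaAb path: halve tileSizeQ when the GPU is under-utilized,
        # doubling the number of head-splitting CTAs.
        tile_size_q = num_heads_q
        if batch_size * max_num_ctas_per_seq_kv <= mp // 8:
            tile_size_q //= 2
        return ("SwapsMmaAb", tile_size_q)
    # KeepsMmaAb path: for numHeadsQ=128, choose 2CTA only when there is
    # enough work to keep all SMs busy; numHeadsQ=64 stays on 1CTA.
    use_2cta = num_heads_q == 128 and batch_size * num_ctas_per_token * 8 > mp
    return ("KeepsMmaAb", "2CTA" if use_2cta else "1CTA")
```

For example, with 148 SMs, batch=1 and 16 CTAs per KV sequence, the numHeadsQ=32 case halves tileSizeQ to 16, while batch=32 keeps tileSizeQ=32.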
Replace the unbounded while loop in run() with a bounded for loop (kMaxKernelSelectionPasses=4) to make termination explicit and verifiable. Each re-selection trigger fires at most once, so the sequence always converges within 3 passes in practice.

Also refactor selectSparseMlaGenerationKernel to remove the `!mSelectNewKernel` guard pattern on mMultiCtasKvMode assignments, replacing it with an `isGmemReduction()` check that is semantically equivalent but clearer: it preserves any Disabled mode set by computeCtaAndClusterConfig on re-entry instead of unconditionally overwriting it. Extract buildLaunchConfig() and setNonPortableClusterIfNeeded() as private helpers to eliminate repeated inline setup in run().

Verified no regression across 4512 MLA + 20832 GQA/context tests, and benchmarked GQA (headsQPerKv=4/8/16) and dense MLA decode across batch=[1,4,16,64,256,512], seq=[1024,4096,16384], dtype=[bf16,fp8]: max absolute delta <10us at all reliably measurable latencies.

Add benchmarks/bench_decode_regression.py covering GQA and MLA decode with the above sweep grid for future regression comparisons.
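The bounded-loop refactor can be sketched as follows (a Python stand-in for the C++ run() loop; the names mirror the ones in the commit message but the structure is illustrative):

```python
K_MAX_KERNEL_SELECTION_PASSES = 4  # each re-selection trigger fires at most once

def run_with_bounded_selection(select_kernel, launch):
    """Replace `while (mSelectNewKernel)` with a bounded for loop so that
    termination is explicit even if a re-selection trigger misbehaves.

    select_kernel() returns (kernel, select_new_kernel); launch(kernel)
    runs the chosen kernel.
    """
    for _ in range(K_MAX_KERNEL_SELECTION_PASSES):
        kernel, select_new_kernel = select_kernel()
        if not select_new_kernel:
            return launch(kernel)
    raise RuntimeError("kernel selection did not converge within bound")
```

A selector that re-selects twice and then settles converges on the third pass, matching the "always converges within 3 passes" claim; a selector that never settles now fails loudly instead of looping forever.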
📝 Walkthrough

This PR updates TRTLLM FMHA kernel artifacts and refactors kernel selection logic. Changes include replacing an unbounded while loop with a bounded for loop (max 4 passes), restructuring CgaSmemReduction fallback handling, and adjusting MLA kernel selection heuristics. A new kernel parameter field is also added.
Summary of Changes (Gemini Code Assist)

This pull request focuses on enhancing the performance and stability of sparse MLA decode operations within the FlashInfer framework. It introduces refined kernel selection heuristics, improves code structure, and updates necessary artifacts to ensure compatibility and optimal performance on SM100/SM103 architectures. The changes aim to address under-utilization issues and prevent potential shared-memory overruns, resulting in significant speedups for specific configurations.
Force-pushed from e36f412 to e2b5819.
Code Review
This pull request introduces significant performance improvements for sparse MLA decode on SM100/SM103 architectures by porting and refining kernel selection heuristics. The introduction of selectSparseMlaGenerationKernel cleanly separates the new logic, and the heuristics for different head counts and batch sizes are well-documented and appear effective based on the provided benchmarks.
The refactoring in run() is a major improvement for code quality and safety, replacing an unbounded while loop with a bounded for loop and extracting helper functions to improve readability. The added benchmark script is also a valuable contribution for future performance tracking.
Overall, this is a high-quality contribution that improves both performance and maintainability. I have one minor suggestion for cleanup in the new benchmark script.
benchmarks/bench_sparse_mla.py (242)
The done variable, initialized on line 207, is incremented here but its value is never used. The progress is already tracked and printed using len(results) and total. This variable can be removed for code clarity.
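A minimal sketch of the suggested cleanup (the surrounding loop is reconstructed for illustration only, since the script itself isn't shown here; progress is derived from `len(results)` and `total`, so no separate `done` counter is needed):

```python
def sweep(configs, run_case):
    """Run every benchmark config, printing progress from len(results) alone."""
    total = len(configs)
    results = []
    for cfg in configs:
        results.append(run_case(cfg))
        # len(results) already tracks completed cases; no `done` variable.
        print(f"[{len(results)}/{total}]", end="\r")
    return results
```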
/bot run
🧹 Nitpick comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
517-518: Remove extraneous semicolon. There's a stray semicolon on line 518, after the maxNumCtasPerSeqKv declaration.

🧹 Proposed fix

```diff
 int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-;
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 517-518, remove the stray semicolon after the declaration of maxNumCtasPerSeqKv: locate the line that declares "int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);" and delete the following extraneous ";" so the statement is a single clean declaration using flashinfer::ceil_div and params.mMaxSeqLenKv.
📒 Files selected for processing (3)

- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/kernelParams.h
[SUCCESS] Pipeline #46582171: 13/20 passed
bkryu left a comment
Both CIs look good to me. Thanks @PerkzZheng
📌 Description

Summary

Improve the kernel-selection heuristics in FlashInfer's trtllm-gen FMHA runner (SM100/SM103 sparse MLA decode), and clean up the kernel selection loop in run().

Changes
selectSparseMlaGenerationKernel() (new function)

Separates sparse MLA selection from the generic generation path with tuned per-config heuristics:

- numHeadsQ=64: now uses KeepsMmaAb (tileSizeQ=64) instead of SwapsMmaAb (tileSizeQ=16); this is the main perf win (see below).
- numHeadsQ=128 (KeepsMmaAb): use 2CTA when `batchSize * numCtasPerToken * 8 > MP` (avoids under-utilizing SMs at small batch).
- numHeadsQ <= 32 (SwapsMmaAb), batch=1: use tileSizeQ/2 when `batchSize * maxNumCtasPerSeqKv ≤ MP/8` (doubles head-splitting parallelism when the GPU is under-utilized).
- CgaSmemReduction guard scoped to MLA kernels: only suppressed for tileSizeQ≥32 (headDimQk=576 exceeds the smem budget); non-MLA paths are unaffected.
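The smem-budget intuition behind that guard can be sketched with rough arithmetic (the bytes-per-element and the focus on the Q tile alone are simplifying assumptions for illustration, not the kernel's actual smem accounting):

```python
def q_tile_smem_bytes(tile_size_q, head_dim_qk=576, bytes_per_elem=2):
    """Rough estimate of shared memory needed for the Q tile alone (bf16)."""
    return tile_size_q * head_dim_qk * bytes_per_elem

# At headDimQk=576, a tileSizeQ=32 Q tile alone is already 36 KB; K/V tiles
# plus the extra CgaSmemReduction staging buffers can then push total usage
# past the per-CTA smem capacity, hence the tileSizeQ >= 32 guard.
```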
run() refactoring

- Replace the `while (mSelectNewKernel)` loop with a bounded `for` loop (kMaxKernelSelectionPasses=4), making convergence explicit and verifiable.
- Extract buildLaunchConfig() and setNonPortableClusterIfNeeded() as private helpers to eliminate duplicated inline setup.
- Replace the `!mSelectNewKernel` guards on mMultiCtasKvMode assignments in selectSparseMlaGenerationKernel with an `isGmemReduction()` check, preserving the Disabled/CgaSmemReduction modes set by computeCtaAndClusterConfig on re-entry.

Artifact

Updates the TRTLLM_GEN_FMHA cubin checksum for the new KeepsMmaAb sparse variants required by the numHeadsQ=64 path.
Benchmark — Sparse MLA decode speedup (SM100, DeepSeek-V3 config)
GPU: B200, topK=128, seqlen=[1k–32k], dtype=[bf16, fp8]
The numHeadsQ=64 improvement (1.35–2.74x) comes entirely from switching
to KeepsMmaAb/tileSizeQ=64 (previously SwapsMmaAb/tileSizeQ=16).
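The tile-size switch changes the launch shape over the Q heads; a quick ceil-div sketch (a Python stand-in for flashinfer::ceil_div, shown for illustration only):

```python
def ceil_div(a, b):
    # Python stand-in for flashinfer::ceil_div.
    return -(-a // b)

num_heads_q = 64
ctas_old = ceil_div(num_heads_q, 16)  # SwapsMmaAb, tileSizeQ=16: 4 CTA splits
ctas_new = ceil_div(num_heads_q, 64)  # KeepsMmaAb, tileSizeQ=64: 1 CTA covers all heads
```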
Tests
test_trtllm_gen_mla.py: 4512/4512 passed (3168 skipped, SM100 only)test_trtllm_gen_attention.py: 20832/20832 passed (23544 skipped)🔍 Related Issues
#2797
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
New Features